{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tf.train.Checkpoint ：变量的保存与恢复 \n",
    "\n",
    "\n",
    "tf.train.Checkpoint 这一强大的变量保存与恢复类，可以使用其 save() 和 restore() 方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言，tf.keras.optimizer 、 tf.Variable 、 tf.keras.Layer 或者 tf.keras.Model 实例都可以被保存。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#我们首先声明一个 Checkpoint：\n",
    "\n",
    "checkpoint = tf.train.Checkpoint(model=model)\n",
    "\n",
    "\"\"\"\n",
    "这里 tf.train.Checkpoint() 接受的初始化参数比较特殊，是一个 **kwargs 。具体而言，是一系列的键值对，键名可以随意取，\n",
    "值为需要保存的对象。例如，如果我们希望保存一个继承 tf.keras.Model \n",
    "的模型实例 model 和一个继承 tf.train.Optimizer 的优化器 optimizer \n",
    "\"\"\"\n",
    "checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)\n",
    "\n",
    "#当模型训练完成需要保存的时候\n",
    "checkpoint.save(save_path_with_prefix)   #save_path_with_prefix 是保存文件的目录 + 前缀。\n",
    "\n",
    "\"\"\"\n",
    "当在其他地方需要为模型重新载入之前保存的参数时，需要再次实例化一个 checkpoint，同时保持键名的一致。\n",
    "再调用 checkpoint 的 restore 方法。就像下面这样：\n",
    "\"\"\"\n",
    "model_to_be_restored = MyModel()                                        # 待恢复参数的同一模型\n",
    "checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)   # 键名保持为“myAwesomeModel”\n",
    "checkpoint.restore(save_path_with_prefix_and_index)\n",
    "\n",
    "\"\"\"\n",
    "当保存了多个文件时，我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) \n",
    "这个辅助函数返回目录下最近一次 checkpoint 的文件名。例如如果 save 目录下有 model.ckpt-1.index 到 model.ckpt-10.index 的 10 个保存文件，\n",
    "tf.train.latest_checkpoint('./save') 即返回 ./save/model.ckpt-10 。\n",
    "\"\"\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 恢复与保存变量的典型代码框架如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train.py 模型训练阶段\n",
    "\n",
    "model = MyModel()\n",
    "# 实例化Checkpoint，指定保存对象为model（如果需要保存Optimizer的参数也可加入）\n",
    "checkpoint = tf.train.Checkpoint(myModel=model)\n",
    "# ...（模型训练代码）\n",
    "# 模型训练完毕后将参数保存到文件（也可以在模型训练过程中每隔一段时间就保存一次）\n",
    "checkpoint.save('./save/model.ckpt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test.py 模型使用阶段\n",
    "\n",
    "model = MyModel()\n",
    "checkpoint = tf.train.Checkpoint(myModel=model)             # 实例化Checkpoint，指定恢复对象为model\n",
    "checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数\n",
    "# 模型使用代码"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型变量的保存和载入："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import argparse\n",
    "from zh.model.mnist.mlp import MLP\n",
    "from zh.model.utils import MNISTLoader\n",
    "\n",
    "parser = argparse.ArgumentParser(description='Process some integers.')\n",
    "parser.add_argument('--mode', default='train', help='train or test')\n",
    "parser.add_argument('--num_epochs', default=1)\n",
    "parser.add_argument('--batch_size', default=50)\n",
    "parser.add_argument('--learning_rate', default=0.001)\n",
    "args = parser.parse_args()\n",
    "data_loader = MNISTLoader()\n",
    "\n",
    "\n",
    "def train():\n",
    "    model = MLP()\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)\n",
    "    num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)\n",
    "    checkpoint = tf.train.Checkpoint(myAwesomeModel=model)      # 实例化Checkpoint，设置保存对象为model\n",
    "    for batch_index in range(1, num_batches+1):                 \n",
    "        X, y = data_loader.get_batch(args.batch_size)\n",
    "        with tf.GradientTape() as tape:\n",
    "            y_pred = model(X)\n",
    "            loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)\n",
    "            loss = tf.reduce_mean(loss)\n",
    "            print(\"batch %d: loss %f\" % (batch_index, loss.numpy()))\n",
    "        grads = tape.gradient(loss, model.variables)\n",
    "        optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))\n",
    "        if batch_index % 100 == 0:                              # 每隔100个Batch保存一次\n",
    "            path = checkpoint.save('./save/model.ckpt')         # 保存模型参数到文件\n",
    "            print(\"model saved to %s\" % path)\n",
    "\n",
    "\n",
    "def test():\n",
    "    model_to_be_restored = MLP()\n",
    "    # 实例化Checkpoint，设置恢复对象为新建立的模型model_to_be_restored\n",
    "    checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)      \n",
    "    checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数\n",
    "    y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)\n",
    "    print(\"test accuracy: %f\" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    if args.mode == 'train':\n",
    "        train()\n",
    "    if args.mode == 'test':   # 在命令行参数中加入 --mode=test 并再次运行代码\n",
    "        test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用 tf.train.CheckpointManager 删除旧的 Checkpoint 以及自定义文件编号\n",
    "\n",
    "\n",
    "\n",
    "directory 参数为文件保存的路径， checkpoint_name 为文件名前缀（不提供则默认为 ckpt ）， max_to_keep 为保留的 Checkpoint 数目。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = tf.train.Checkpoint(model=model)\n",
    "manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import argparse\n",
    "from zh.model.mnist.mlp import MLP\n",
    "from zh.model.utils import MNISTLoader\n",
    "\n",
    "parser = argparse.ArgumentParser(description='Process some integers.')\n",
    "parser.add_argument('--mode', default='train', help='train or test')\n",
    "parser.add_argument('--num_epochs', default=1)\n",
    "parser.add_argument('--batch_size', default=50)\n",
    "parser.add_argument('--learning_rate', default=0.001)\n",
    "args = parser.parse_args()\n",
    "data_loader = MNISTLoader()\n",
    "\n",
    "\n",
    "def train():\n",
    "    model = MLP()\n",
    "    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)\n",
    "    num_batches = int(data_loader.num_train_data // args.batch_size * args.num_epochs)\n",
    "    checkpoint = tf.train.Checkpoint(myAwesomeModel=model)      \n",
    "    # 使用tf.train.CheckpointManager管理Checkpoint\n",
    "    manager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)\n",
    "    for batch_index in range(1, num_batches):\n",
    "        X, y = data_loader.get_batch(args.batch_size)\n",
    "        with tf.GradientTape() as tape:\n",
    "            y_pred = model(X)\n",
    "            loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)\n",
    "            loss = tf.reduce_mean(loss)\n",
    "            print(\"batch %d: loss %f\" % (batch_index, loss.numpy()))\n",
    "        grads = tape.gradient(loss, model.variables)\n",
    "        optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))\n",
    "        if batch_index % 100 == 0:\n",
    "            # 使用CheckpointManager保存模型参数到文件并自定义编号\n",
    "            path = manager.save(checkpoint_number=batch_index)         \n",
    "            print(\"model saved to %s\" % path)\n",
    "\n",
    "\n",
    "def test():\n",
    "    model_to_be_restored = MLP()\n",
    "    checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)      \n",
    "    checkpoint.restore(tf.train.latest_checkpoint('./save'))\n",
    "    y_pred = np.argmax(model_to_be_restored.predict(data_loader.test_data), axis=-1)\n",
    "    print(\"test accuracy: %f\" % (sum(y_pred == data_loader.test_label) / data_loader.num_test_data))\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    if args.mode == 'train':\n",
    "        train()\n",
    "    if args.mode == 'test':\n",
    "        test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TensorBoard：训练过程可视化\n",
    "\n",
    "\n",
    "首先在代码目录下建立一个文件夹（如 ./tensorboard ）存放 TensorBoard 的记录文件，并在代码中实例化一个记录器：\n",
    "\n",
    "\n",
    "TensorBoard 的使用有以下注意事项：\n",
    "\n",
    "如果需要重新训练，需要删除掉记录文件夹内的信息并重启 TensorBoard（或者建立一个新的记录文件夹并开启 TensorBoard， --logdir 参数设置为新建立的文件夹）；\n",
    "\n",
    "记录文件夹目录保持全英文。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_writer = tf.summary.create_file_writer('./tensorboard')     # 参数为记录文件所保存的目录\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，当需要记录训练过程中的参数时，通过 with 语句指定希望使用的记录器，并对需要记录的参数（一般是 scalar）运行 tf.summary.scalar(name, tensor, step=batch_index) ，即可将训练过程中参数在 step 时候的值记录下来。这里的 step 参数可根据自己的需要自行制定，一般可设置为当前训练过程中的 batch 序号。整体框架如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 开始模型训练\n",
    "for batch_index in range(num_batches):\n",
    "    # ...（训练代码，当前batch的损失值放入变量loss中）\n",
    "    with summary_writer.as_default():                               # 希望使用的记录器\n",
    "        tf.summary.scalar(\"loss\", loss, step=batch_index)   #每运行一次 tf.summary.scalar() ，记录器就会向记录文件中写入一条记录\n",
    "        tf.summary.scalar(\"MyScalar\", my_scalar, step=batch_index)  # 还可以添加其他自定义的变量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorboard --logdir=./tensorboard    #在代码目录打开终端"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 查看 Graph 和 Profile 信息\n",
    "\n",
    "除此以外，我们可以在训练时使用 tf.summary.trace_on 开启 Trace，此时 TensorFlow 会将训练时的大量信息（如计算图的结构，每个操作所耗费的时间等）记录下来。在训练完成后，使用 tf.summary.trace_export 将记录结果输出到文件。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.summary.trace_on(graph=True, profiler=True)  # 开启Trace，可以记录图结构和profile信息\n",
    "# 进行训练\n",
    "with summary_writer.as_default():\n",
    "    tf.summary.trace_export(name=\"model_trace\", step=0, profiler_outdir=log_dir)    # 保存Trace信息到文件\n",
    "    \n",
    "    \n",
    "\"\"\"\n",
    "之后，我们就可以在 TensorBoard 中选择 “Profile”，以时间轴的方式查看各操作的耗时情况。\n",
    "如果使用了 tf.function 建立了计算图，也可以点击 “Graphs” 查看图结构。\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from zh.model.mnist.mlp import MLP\n",
    "from zh.model.utils import MNISTLoader\n",
    "\n",
    "num_batches = 1000\n",
    "batch_size = 50\n",
    "learning_rate = 0.001\n",
    "log_dir = 'tensorboard'\n",
    "model = MLP()\n",
    "data_loader = MNISTLoader()\n",
    "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
    "summary_writer = tf.summary.create_file_writer(log_dir)     # 实例化记录器\n",
    "tf.summary.trace_on(profiler=True)  # 开启Trace（可选）\n",
    "for batch_index in range(num_batches):\n",
    "    X, y = data_loader.get_batch(batch_size)\n",
    "    with tf.GradientTape() as tape:\n",
    "        y_pred = model(X)\n",
    "        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)\n",
    "        loss = tf.reduce_mean(loss)\n",
    "        print(\"batch %d: loss %f\" % (batch_index, loss.numpy()))\n",
    "        with summary_writer.as_default():                           # 指定记录器\n",
    "            tf.summary.scalar(\"loss\", loss, step=batch_index)       # 将当前损失函数的值写入记录器\n",
    "    grads = tape.gradient(loss, model.variables)\n",
    "    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))\n",
    "with summary_writer.as_default():\n",
    "    tf.summary.trace_export(name=\"model_trace\", step=0, profiler_outdir=log_dir)    # 保存Trace信息到文件（可选）"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
